from DNN_printer import DNN_printer

import torch

from transformers import (
    RobertaTokenizer,
    RobertaConfig,
    RobertaModel
)


# Load the CodeBERT checkpoint (a RoBERTa-architecture model) and print a
# layer-by-layer summary of the encoder with DNN_printer.
checkpoint = 'microsoft/codebert-base'

# tokenizer / ast_tokenizer / roberta_config are not used by the profiling
# call below; NOTE(review): presumably kept for code that imports this
# module — verify before removing.
tokenizer = RobertaTokenizer.from_pretrained(checkpoint)
ast_tokenizer = RobertaTokenizer.from_pretrained(checkpoint)
roberta = RobertaModel.from_pretrained(checkpoint)
roberta_config = RobertaConfig.from_pretrained(checkpoint)

# Dummy-input geometry handed to the profiler: BATCH_SIZE sequences of
# SEQ_LEN token ids. NOTE(review): 512 is the usual maximum sequence length
# for RoBERTa-base-style checkpoints — confirm against
# roberta_config.max_position_embeddings.
BATCH_SIZE = 8
SEQ_LEN = 512

# Use the GPU when one is present instead of failing outright on CPU-only
# hosts. NOTE(review): assumes DNN_printer accepts device='cpu' — confirm.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# RobertaModel consumes integer token ids rather than float features, hence
# the LongTensor dtype (presumably used by DNN_printer to build its dummy
# input — confirm).
DNN_printer(roberta, (SEQ_LEN,), BATCH_SIZE, device=DEVICE, dtype=torch.LongTensor)
